{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 7. Neural Machine Translation and Models with Attention"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "I recommend you take a look at these material first."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture9.pdf\n",
    "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture10.pdf\n",
    "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture11.pdf\n",
    "* https://arxiv.org/pdf/1409.0473.pdf\n",
    "* https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation-batched.ipynb\n",
    "* https://medium.com/huggingface/understanding-emotions-from-keras-to-pytorch-3ccb61d5a983\n",
    "* http://www.manythings.org/anki/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import nltk\n",
    "import random\n",
    "import numpy as np\n",
    "from collections import Counter, OrderedDict\n",
    "import nltk\n",
    "from copy import deepcopy\n",
    "import os\n",
    "import re\n",
    "import unicodedata\n",
    "flatten = lambda l: [item for sublist in l for item in sublist]\n",
    "\n",
    "from torch.nn.utils.rnn import PackedSequence,pack_padded_sequence\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as ticker\n",
    "random.seed(1024)\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "USE_CUDA = torch.cuda.is_available()\n",
    "gpus = [0]\n",
    "torch.cuda.set_device(gpus[0])\n",
    "\n",
    "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
    "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
    "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def getBatch(batch_size, train_data):\n",
    "    random.shuffle(train_data)\n",
    "    sindex=0\n",
    "    eindex=batch_size\n",
    "    while eindex < len(train_data):\n",
    "        batch = train_data[sindex: eindex]\n",
    "        temp = eindex\n",
    "        eindex = eindex + batch_size\n",
    "        sindex = temp\n",
    "        yield batch\n",
    "    \n",
    "    if eindex >= len(train_data):\n",
    "        batch = train_data[sindex:]\n",
    "        yield batch"
   ]
  },
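  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of `getBatch` on a toy list (hypothetical data): it shuffles the data in place and yields mini-batches, and the final batch may be smaller than `batch_size`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# sanity check on hypothetical toy data: every item is covered once per pass\n",
    "dummy = list(range(10))\n",
    "print([len(b) for b in getBatch(3, dummy)]) # [3, 3, 3, 1]"
   ]
  },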
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Padding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../images/07.pad_to_sequence.png\">\n",
    "<center>borrowed image from https://medium.com/huggingface/understanding-emotions-from-keras-to-pytorch-3ccb61d5a983</center>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# It is for Sequence 2 Sequence format\n",
    "def pad_to_batch(batch, x_to_ix, y_to_ix):\n",
    "    \n",
    "    sorted_batch =  sorted(batch, key=lambda b:b[0].size(1), reverse=True) # sort by len\n",
    "    x,y = list(zip(*sorted_batch))\n",
    "    max_x = max([s.size(1) for s in x])\n",
    "    max_y = max([s.size(1) for s in y])\n",
    "    x_p, y_p = [], []\n",
    "    for i in range(len(batch)):\n",
    "        if x[i].size(1) < max_x:\n",
    "            x_p.append(torch.cat([x[i], Variable(LongTensor([x_to_ix['<PAD>']] * (max_x - x[i].size(1)))).view(1, -1)], 1))\n",
    "        else:\n",
    "            x_p.append(x[i])\n",
    "        if y[i].size(1) < max_y:\n",
    "            y_p.append(torch.cat([y[i], Variable(LongTensor([y_to_ix['<PAD>']] * (max_y - y[i].size(1)))).view(1, -1)], 1))\n",
    "        else:\n",
    "            y_p.append(y[i])\n",
    "        \n",
    "    input_var = torch.cat(x_p)\n",
    "    target_var = torch.cat(y_p)\n",
    "    input_len = [list(map(lambda s: s ==0, t.data)).count(False) for t in input_var]\n",
    "    target_len = [list(map(lambda s: s ==0, t.data)).count(False) for t in target_var]\n",
    "    \n",
    "    return input_var, target_var, input_len, target_len"
   ]
  },
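  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of `pad_to_batch` on a toy batch (the vocabularies here are hypothetical; the only real requirement is that `<PAD>` maps to index 0):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# hypothetical toy vocabularies; <PAD> must map to 0 for the length computation\n",
    "toy_src = {'<PAD>': 0, '<UNK>': 1, 'a': 2, 'b': 3}\n",
    "toy_trg = {'<PAD>': 0, '<UNK>': 1, 'c': 2, 'd': 3}\n",
    "toy_batch = [\n",
    "    (Variable(LongTensor([[2, 3]])), Variable(LongTensor([[2]]))),\n",
    "    (Variable(LongTensor([[2, 3, 3]])), Variable(LongTensor([[2, 3]]))),\n",
    "]\n",
    "x, y, x_len, y_len = pad_to_batch(toy_batch, toy_src, toy_trg)\n",
    "print(x)     # 2 x 3 matrix, shorter row padded with 0\n",
    "print(x_len) # [3, 2]"
   ]
  },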
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def prepare_sequence(seq, to_index):\n",
    "    idxs = list(map(lambda w: to_index[w] if to_index.get(w) is not None else to_index[\"<UNK>\"], seq))\n",
    "    return Variable(LongTensor(idxs))"
   ]
  },
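  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A tiny check of `prepare_sequence` with a hypothetical vocabulary: out-of-vocabulary words fall back to `<UNK>`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# hypothetical toy vocabulary\n",
    "toy_ix = {'<PAD>': 0, '<UNK>': 1, 'hello': 2}\n",
    "print(prepare_sequence(['hello', 'world'], toy_ix)) # indices [2, 1]"
   ]
  },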
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data load and Preprocessing "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Borrowed code from https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation-batched.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Turn a Unicode string to plain ASCII, thanks to http://stackoverflow.com/a/518232/2809427\n",
    "def unicode_to_ascii(s):\n",
    "    return ''.join(\n",
    "        c for c in unicodedata.normalize('NFD', s)\n",
    "        if unicodedata.category(c) != 'Mn'\n",
    "    )\n",
    "\n",
    "# Lowercase, trim, and remove non-letter characters\n",
    "def normalize_string(s):\n",
    "    s = unicode_to_ascii(s.lower().strip())\n",
    "    s = re.sub(r\"([,.!?])\", r\" \\1 \", s)\n",
    "    s = re.sub(r\"[^a-zA-Z,.!?]+\", r\" \", s)\n",
    "    s = re.sub(r\"\\s+\", r\" \", s).strip()\n",
    "    return s"
   ]
  },
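  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick check of the normalizer on a hypothetical sentence: accents are stripped, punctuation is split off, and everything is lowercased."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# hypothetical example\n",
    "print(normalize_string('Je suis étudiant!')) # je suis etudiant !"
   ]
  },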
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<center><h3>French -> English</h3></center>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "corpus = open('../dataset/eng-fra.txt', 'r', encoding='utf-8').readlines()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "142787"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(corpus)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "corpus = corpus[:30000] # for practice"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "MIN_LENGTH = 3\n",
    "MAX_LENGTH = 25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "29830 29830\n",
      "['i', 'see', '.'] ['je', 'comprends', '.']\n",
      "CPU times: user 836 ms, sys: 8 ms, total: 844 ms\n",
      "Wall time: 843 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "X_r, y_r = [], [] # raw\n",
    "\n",
    "for parallel in corpus:\n",
    "    so,ta = parallel[:-1].split('\\t')\n",
    "    if so.strip() == \"\" or ta.strip() == \"\": \n",
    "        continue\n",
    "    \n",
    "    normalized_so = normalize_string(so).split()\n",
    "    normalized_ta = normalize_string(ta).split()\n",
    "    \n",
    "    if len(normalized_so) >= MIN_LENGTH and len(normalized_so) <= MAX_LENGTH \\\n",
    "    and len(normalized_ta) >= MIN_LENGTH and len(normalized_ta) <= MAX_LENGTH:\n",
    "        X_r.append(normalized_so)\n",
    "        y_r.append(normalized_ta)\n",
    "    \n",
    "\n",
    "print(len(X_r), len(y_r))\n",
    "print(X_r[0], y_r[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build Vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4433 7704\n"
     ]
    }
   ],
   "source": [
    "source_vocab = list(set(flatten(X_r)))\n",
    "target_vocab = list(set(flatten(y_r)))\n",
    "print(len(source_vocab), len(target_vocab))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "source2index = {'<PAD>': 0, '<UNK>': 1, '<s>': 2, '</s>': 3}\n",
    "for vo in source_vocab:\n",
    "    if source2index.get(vo) is None:\n",
    "        source2index[vo] = len(source2index)\n",
    "index2source = {v:k for k, v in source2index.items()}\n",
    "\n",
    "target2index = {'<PAD>': 0, '<UNK>': 1, '<s>': 2, '</s>': 3}\n",
    "for vo in target_vocab:\n",
    "    if target2index.get(vo) is None:\n",
    "        target2index[vo] = len(target2index)\n",
    "index2target = {v:k for k, v in target2index.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare train data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.16 s, sys: 364 ms, total: 2.52 s\n",
      "Wall time: 2.89 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "X_p, y_p = [], []\n",
    "\n",
    "for so, ta in zip(X_r, y_r):\n",
    "    X_p.append(prepare_sequence(so + ['</s>'], source2index).view(1, -1))\n",
    "    y_p.append(prepare_sequence(ta + ['</s>'], target2index).view(1, -1))\n",
    "    \n",
    "train_data = list(zip(X_p, y_p))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Modeling "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../images/07.seq2seq.png\">\n",
    "<center>borrowd image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture10.pdf</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you're not familier with <strong>pack_padded_sequence</strong> and <strong>pad_packed_sequence</strong>, check this <a href=\"https://medium.com/huggingface/understanding-emotions-from-keras-to-pytorch-3ccb61d5a983\">post</a>."
   ]
  },
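  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch (hypothetical values), packing drops the padded positions so the GRU never sees them, and unpacking restores the padded layout. Note that the batch must be sorted by length, descending."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# hypothetical toy batch: B=2, T=3, D=1, sorted by length descending\n",
    "seqs = Variable(FloatTensor([[1, 2, 3], [4, 5, 0]])).unsqueeze(2)\n",
    "packed = pack_padded_sequence(seqs, [3, 2], batch_first=True)\n",
    "unpacked, lengths = torch.nn.utils.rnn.pad_packed_sequence(packed, batch_first=True)\n",
    "print(unpacked.squeeze(2)) # recovers the padded matrix\n",
    "print(lengths)             # [3, 2]"
   ]
  },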
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    def __init__(self, input_size, embedding_size,hidden_size, n_layers=1,bidirec=False):\n",
    "        super(Encoder, self).__init__()\n",
    "        \n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.n_layers = n_layers\n",
    "        \n",
    "        self.embedding = nn.Embedding(input_size, embedding_size)\n",
    "        \n",
    "        if bidirec:\n",
    "            self.n_direction = 2 \n",
    "            self.gru = nn.GRU(embedding_size, hidden_size, n_layers, batch_first=True, bidirectional=True)\n",
    "        else:\n",
    "            self.n_direction = 1\n",
    "            self.gru = nn.GRU(embedding_size, hidden_size, n_layers, batch_first=True)\n",
    "    \n",
    "    def init_hidden(self, inputs):\n",
    "        hidden = Variable(torch.zeros(self.n_layers * self.n_direction, inputs.size(0), self.hidden_size))\n",
    "        return hidden.cuda() if USE_CUDA else hidden\n",
    "    \n",
    "    def init_weight(self):\n",
    "        self.embedding.weight = nn.init.xavier_uniform(self.embedding.weight)\n",
    "        self.gru.weight_hh_l0 = nn.init.xavier_uniform(self.gru.weight_hh_l0)\n",
    "        self.gru.weight_ih_l0 = nn.init.xavier_uniform(self.gru.weight_ih_l0)\n",
    "    \n",
    "    def forward(self, inputs, input_lengths):\n",
    "        \"\"\"\n",
    "        inputs : B, T (LongTensor)\n",
    "        input_lengths : real lengths of input batch (list)\n",
    "        \"\"\"\n",
    "        hidden = self.init_hidden(inputs)\n",
    "        \n",
    "        embedded = self.embedding(inputs)\n",
    "        packed = pack_padded_sequence(embedded, input_lengths, batch_first=True)\n",
    "        outputs, hidden = self.gru(packed, hidden)\n",
    "        outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True) # unpack (back to padded)\n",
    "                \n",
    "        if self.n_layers > 1:\n",
    "            if self.n_direction == 2:\n",
    "                hidden = hidden[-2:]\n",
    "            else:\n",
    "                hidden = hidden[-1]\n",
    "        \n",
    "        return outputs, torch.cat([h for h in hidden], 1).unsqueeze(1)"
   ]
  },
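  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick shape check of the encoder on dummy inputs (hypothetical sizes): with `bidirec=True`, both the outputs and the returned hidden state have dimension `2 * hidden_size`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# hypothetical sizes; inputs must be sorted by length, descending\n",
    "enc = Encoder(input_size=10, embedding_size=8, hidden_size=16, n_layers=1, bidirec=True)\n",
    "if USE_CUDA:\n",
    "    enc = enc.cuda()\n",
    "dummy = Variable(LongTensor([[1, 2, 3], [4, 5, 0]])) # B=2, T=3\n",
    "out, h = enc(dummy, [3, 2])\n",
    "print(out.size()) # (2, 3, 32) = B, T, 2 * hidden_size\n",
    "print(h.size())   # (2, 1, 32) = B, 1, 2 * hidden_size"
   ]
  },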
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Attention Mechanism ( https://arxiv.org/pdf/1409.0473.pdf )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "I used general-type for score function $h_t^TW_ah_s^-$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../images/07.attention-mechanism.png\">\n",
    "<center>borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture10.pdf</center>"
   ]
  },
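  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A standalone sketch of the \"general\" score on random tensors (hypothetical shapes), mirroring what `Decoder.Attention` below does with `self.attn`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# hypothetical shapes: B batch size, T source length, D hidden dimension\n",
    "B, T, D = 2, 4, 8\n",
    "W_a = nn.Linear(D, D)                  # plays the role of self.attn\n",
    "h_s = Variable(torch.randn(B, T, D))   # encoder outputs\n",
    "h_t = Variable(torch.randn(1, B, D))   # current decoder hidden state\n",
    "energies = W_a(h_s.view(B * T, D)).view(B, T, D)\n",
    "scores = energies.bmm(h_t[0].unsqueeze(2)).squeeze(2) # B, T\n",
    "alpha = F.softmax(scores, 1)           # weights over source positions\n",
    "context = alpha.unsqueeze(1).bmm(h_s)  # B, 1, D weighted sum\n",
    "print(alpha.size(), context.size())"
   ]
  },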
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Decoder(nn.Module):\n",
    "    def __init__(self, input_size, embedding_size, hidden_size, n_layers=1, dropout_p=0.1):\n",
    "        super(Decoder, self).__init__()\n",
    "        \n",
    "        self.hidden_size = hidden_size\n",
    "        self.n_layers = n_layers\n",
    "        \n",
    "        # Define the layers\n",
    "        self.embedding = nn.Embedding(input_size, embedding_size)\n",
    "        self.dropout = nn.Dropout(dropout_p)\n",
    "        \n",
    "        self.gru = nn.GRU(embedding_size + hidden_size, hidden_size, n_layers, batch_first=True)\n",
    "        self.linear = nn.Linear(hidden_size * 2, input_size)\n",
    "        self.attn = nn.Linear(self.hidden_size, self.hidden_size) # Attention\n",
    "    \n",
    "    def init_hidden(self,inputs):\n",
    "        hidden = Variable(torch.zeros(self.n_layers, inputs.size(0), self.hidden_size))\n",
    "        return hidden.cuda() if USE_CUDA else hidden\n",
    "    \n",
    "    \n",
    "    def init_weight(self):\n",
    "        self.embedding.weight = nn.init.xavier_uniform(self.embedding.weight)\n",
    "        self.gru.weight_hh_l0 = nn.init.xavier_uniform(self.gru.weight_hh_l0)\n",
    "        self.gru.weight_ih_l0 = nn.init.xavier_uniform(self.gru.weight_ih_l0)\n",
    "        self.linear.weight = nn.init.xavier_uniform(self.linear.weight)\n",
    "        self.attn.weight = nn.init.xavier_uniform(self.attn.weight)\n",
    "#         self.attn.bias.data.fill_(0)\n",
    "    \n",
    "    def Attention(self, hidden, encoder_outputs, encoder_maskings):\n",
    "        \"\"\"\n",
    "        hidden : 1,B,D\n",
    "        encoder_outputs : B,T,D\n",
    "        encoder_maskings : B,T # ByteTensor\n",
    "        \"\"\"\n",
    "        hidden = hidden[0].unsqueeze(2)  # (1,B,D) -> (B,D,1)\n",
    "        \n",
    "        batch_size = encoder_outputs.size(0) # B\n",
    "        max_len = encoder_outputs.size(1) # T\n",
    "        energies = self.attn(encoder_outputs.contiguous().view(batch_size * max_len, -1)) # B*T,D -> B*T,D\n",
    "        energies = energies.view(batch_size,max_len, -1) # B,T,D\n",
    "        attn_energies = energies.bmm(hidden).squeeze(2) # B,T,D * B,D,1 --> B,T\n",
    "        \n",
    "#         if isinstance(encoder_maskings,torch.autograd.variable.Variable):\n",
    "#             attn_energies = attn_energies.masked_fill(encoder_maskings,float('-inf'))#-1e12) # PAD masking\n",
    "        \n",
    "        alpha = F.softmax(attn_energies,1) # B,T\n",
    "        alpha = alpha.unsqueeze(1) # B,1,T\n",
    "        context = alpha.bmm(encoder_outputs) # B,1,T * B,T,D => B,1,D\n",
    "        \n",
    "        return context, alpha\n",
    "    \n",
    "    \n",
    "    def forward(self, inputs, context, max_length, encoder_outputs, encoder_maskings=None, is_training=False):\n",
    "        \"\"\"\n",
    "        inputs : B,1 (LongTensor, START SYMBOL)\n",
    "        context : B,1,D (FloatTensor, Last encoder hidden state)\n",
    "        max_length : int, max length to decode # for batch\n",
    "        encoder_outputs : B,T,D\n",
    "        encoder_maskings : B,T # ByteTensor\n",
    "        is_training : bool, this is because adapt dropout only training step.\n",
    "        \"\"\"\n",
    "        # Get the embedding of the current input word\n",
    "        embedded = self.embedding(inputs)\n",
    "        hidden = self.init_hidden(inputs)\n",
    "        if is_training:\n",
    "            embedded = self.dropout(embedded)\n",
    "        \n",
    "        decode = []\n",
    "        # Apply GRU to the output so far\n",
    "        for i in range(max_length):\n",
    "\n",
    "            _, hidden = self.gru(torch.cat((embedded, context), 2), hidden) # h_t = f(h_{t-1},y_{t-1},c)\n",
    "            concated = torch.cat((hidden, context.transpose(0, 1)), 2) # y_t = g(h_t,y_{t-1},c)\n",
    "            score = self.linear(concated.squeeze(0))\n",
    "            softmaxed = F.log_softmax(score,1)\n",
    "            decode.append(softmaxed)\n",
    "            decoded = softmaxed.max(1)[1]\n",
    "            embedded = self.embedding(decoded).unsqueeze(1) # y_{t-1}\n",
    "            if is_training:\n",
    "                embedded = self.dropout(embedded)\n",
    "            \n",
    "            # compute next context vector using attention\n",
    "            context, alpha = self.Attention(hidden, encoder_outputs, encoder_maskings)\n",
    "            \n",
    "        #  column-wise concat, reshape!!\n",
    "        scores = torch.cat(decode, 1)\n",
    "        return scores.view(inputs.size(0) * max_length, -1)\n",
    "    \n",
    "    def decode(self, context, encoder_outputs):\n",
    "        start_decode = Variable(LongTensor([[target2index['<s>']] * 1])).transpose(0, 1)\n",
    "        embedded = self.embedding(start_decode)\n",
    "        hidden = self.init_hidden(start_decode)\n",
    "        \n",
    "        decodes = []\n",
    "        attentions = []\n",
    "        decoded = embedded\n",
    "        while decoded.data.tolist()[0] != target2index['</s>']: # until </s>\n",
    "            _, hidden = self.gru(torch.cat((embedded, context), 2), hidden) # h_t = f(h_{t-1},y_{t-1},c)\n",
    "            concated = torch.cat((hidden, context.transpose(0, 1)), 2) # y_t = g(h_t,y_{t-1},c)\n",
    "            score = self.linear(concated.squeeze(0))\n",
    "            softmaxed = F.log_softmax(score,1)\n",
    "            decodes.append(softmaxed)\n",
    "            decoded = softmaxed.max(1)[1]\n",
    "            embedded = self.embedding(decoded).unsqueeze(1) # y_{t-1}\n",
    "            context, alpha = self.Attention(hidden, encoder_outputs,None)\n",
    "            attentions.append(alpha.squeeze(1))\n",
    "        \n",
    "        return torch.cat(decodes).max(1)[1], torch.cat(attentions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It takes for a while if you use just cpu...."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "EPOCH = 50\n",
    "BATCH_SIZE = 64\n",
    "EMBEDDING_SIZE = 300\n",
    "HIDDEN_SIZE = 512\n",
    "LR = 0.001\n",
    "DECODER_LEARNING_RATIO = 5.0\n",
    "RESCHEDULED = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "encoder = Encoder(len(source2index), EMBEDDING_SIZE, HIDDEN_SIZE, 3, True)\n",
    "decoder = Decoder(len(target2index), EMBEDDING_SIZE, HIDDEN_SIZE * 2)\n",
    "encoder.init_weight()\n",
    "decoder.init_weight()\n",
    "\n",
    "if USE_CUDA:\n",
    "    encoder = encoder.cuda()\n",
    "    decoder = decoder.cuda()\n",
    "\n",
    "loss_function = nn.CrossEntropyLoss(ignore_index=0)\n",
    "enc_optimizer = optim.Adam(encoder.parameters(), lr=LR)\n",
    "dec_optimizer = optim.Adam(decoder.parameters(), lr=LR * DECODER_LEARNING_RATIO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {
    "collapsed": false,
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[00/50] [000/466] mean_loss : 8.94\n",
      "[00/50] [200/466] mean_loss : 4.27\n",
      "[00/50] [400/466] mean_loss : 3.38\n",
      "[01/50] [000/466] mean_loss : 2.79\n",
      "[01/50] [200/466] mean_loss : 2.59\n",
      "[01/50] [400/466] mean_loss : 2.40\n",
      "[02/50] [000/466] mean_loss : 1.86\n",
      "[02/50] [200/466] mean_loss : 1.97\n",
      "[02/50] [400/466] mean_loss : 1.93\n",
      "[03/50] [000/466] mean_loss : 1.58\n",
      "[03/50] [200/466] mean_loss : 1.61\n",
      "[03/50] [400/466] mean_loss : 1.70\n",
      "[04/50] [000/466] mean_loss : 1.42\n",
      "[04/50] [200/466] mean_loss : 1.43\n",
      "[04/50] [400/466] mean_loss : 1.48\n",
      "[05/50] [000/466] mean_loss : 1.09\n",
      "[05/50] [200/466] mean_loss : 1.29\n",
      "[05/50] [400/466] mean_loss : 1.31\n",
      "[06/50] [000/466] mean_loss : 1.00\n",
      "[06/50] [200/466] mean_loss : 1.17\n",
      "[06/50] [400/466] mean_loss : 1.23\n",
      "[07/50] [000/466] mean_loss : 0.82\n",
      "[07/50] [200/466] mean_loss : 1.07\n",
      "[07/50] [400/466] mean_loss : 1.14\n",
      "[08/50] [000/466] mean_loss : 1.10\n",
      "[08/50] [200/466] mean_loss : 1.02\n",
      "[08/50] [400/466] mean_loss : 1.09\n",
      "[09/50] [000/466] mean_loss : 0.86\n",
      "[09/50] [200/466] mean_loss : 0.96\n",
      "[09/50] [400/466] mean_loss : 1.04\n",
      "[10/50] [000/466] mean_loss : 0.85\n",
      "[10/50] [200/466] mean_loss : 0.90\n",
      "[10/50] [400/466] mean_loss : 0.97\n",
      "[11/50] [000/466] mean_loss : 0.97\n",
      "[11/50] [200/466] mean_loss : 0.88\n",
      "[11/50] [400/466] mean_loss : 0.95\n",
      "[12/50] [000/466] mean_loss : 0.77\n",
      "[12/50] [200/466] mean_loss : 0.86\n",
      "[12/50] [400/466] mean_loss : 0.92\n",
      "[13/50] [000/466] mean_loss : 0.70\n",
      "[13/50] [200/466] mean_loss : 0.80\n",
      "[13/50] [400/466] mean_loss : 0.87\n",
      "[14/50] [000/466] mean_loss : 0.66\n",
      "[14/50] [200/466] mean_loss : 0.74\n",
      "[14/50] [400/466] mean_loss : 0.84\n",
      "[15/50] [000/466] mean_loss : 0.66\n",
      "[15/50] [200/466] mean_loss : 0.72\n",
      "[15/50] [400/466] mean_loss : 0.81\n",
      "[16/50] [000/466] mean_loss : 0.55\n",
      "[16/50] [200/466] mean_loss : 0.72\n",
      "[16/50] [400/466] mean_loss : 0.80\n",
      "[17/50] [000/466] mean_loss : 0.64\n",
      "[17/50] [200/466] mean_loss : 0.70\n",
      "[17/50] [400/466] mean_loss : 0.80\n",
      "[18/50] [000/466] mean_loss : 0.62\n",
      "[18/50] [200/466] mean_loss : 0.69\n",
      "[18/50] [400/466] mean_loss : 0.77\n",
      "[19/50] [000/466] mean_loss : 0.49\n",
      "[19/50] [200/466] mean_loss : 0.74\n",
      "[19/50] [400/466] mean_loss : 0.80\n",
      "[20/50] [000/466] mean_loss : 0.55\n",
      "[20/50] [200/466] mean_loss : 0.67\n",
      "[20/50] [400/466] mean_loss : 0.76\n",
      "[21/50] [000/466] mean_loss : 0.64\n",
      "[21/50] [200/466] mean_loss : 0.67\n",
      "[21/50] [400/466] mean_loss : 0.75\n",
      "[22/50] [000/466] mean_loss : 0.60\n",
      "[22/50] [200/466] mean_loss : 0.63\n",
      "[22/50] [400/466] mean_loss : 0.70\n",
      "[23/50] [000/466] mean_loss : 0.60\n",
      "[23/50] [200/466] mean_loss : 0.60\n",
      "[23/50] [400/466] mean_loss : 0.67\n",
      "[24/50] [000/466] mean_loss : 0.57\n",
      "[24/50] [200/466] mean_loss : 0.61\n",
      "[24/50] [400/466] mean_loss : 0.68\n",
      "[25/50] [000/466] mean_loss : 0.50\n",
      "[25/50] [200/466] mean_loss : 0.61\n",
      "[25/50] [400/466] mean_loss : 0.68\n",
      "[26/50] [000/466] mean_loss : 0.53\n",
      "[26/50] [200/466] mean_loss : 0.53\n",
      "[26/50] [400/466] mean_loss : 0.51\n",
      "[27/50] [000/466] mean_loss : 0.58\n",
      "[27/50] [200/466] mean_loss : 0.50\n",
      "[27/50] [400/466] mean_loss : 0.49\n",
      "[28/50] [000/466] mean_loss : 0.40\n",
      "[28/50] [200/466] mean_loss : 0.48\n",
      "[28/50] [400/466] mean_loss : 0.47\n",
      "[29/50] [000/466] mean_loss : 0.45\n",
      "[29/50] [200/466] mean_loss : 0.46\n",
      "[29/50] [400/466] mean_loss : 0.45\n",
      "[30/50] [000/466] mean_loss : 0.56\n",
      "[30/50] [200/466] mean_loss : 0.44\n",
      "[30/50] [400/466] mean_loss : 0.45\n",
      "[31/50] [000/466] mean_loss : 0.46\n",
      "[31/50] [200/466] mean_loss : 0.43\n",
      "[31/50] [400/466] mean_loss : 0.43\n",
      "[32/50] [000/466] mean_loss : 0.30\n",
      "[32/50] [200/466] mean_loss : 0.41\n",
      "[32/50] [400/466] mean_loss : 0.42\n",
      "[33/50] [000/466] mean_loss : 0.30\n",
      "[33/50] [200/466] mean_loss : 0.40\n",
      "[33/50] [400/466] mean_loss : 0.41\n",
      "[34/50] [000/466] mean_loss : 0.34\n",
      "[34/50] [200/466] mean_loss : 0.40\n",
      "[34/50] [400/466] mean_loss : 0.40\n",
      "[35/50] [000/466] mean_loss : 0.32\n",
      "[35/50] [200/466] mean_loss : 0.39\n",
      "[35/50] [400/466] mean_loss : 0.39\n",
      "[36/50] [000/466] mean_loss : 0.31\n",
      "[36/50] [200/466] mean_loss : 0.39\n",
      "[36/50] [400/466] mean_loss : 0.38\n",
      "[37/50] [000/466] mean_loss : 0.39\n",
      "[37/50] [200/466] mean_loss : 0.38\n",
      "[37/50] [400/466] mean_loss : 0.38\n",
      "[38/50] [000/466] mean_loss : 0.33\n",
      "[38/50] [200/466] mean_loss : 0.37\n",
      "[38/50] [400/466] mean_loss : 0.37\n",
      "[39/50] [000/466] mean_loss : 0.39\n",
      "[39/50] [200/466] mean_loss : 0.37\n",
      "[39/50] [400/466] mean_loss : 0.37\n",
      "[40/50] [000/466] mean_loss : 0.41\n",
      "[40/50] [200/466] mean_loss : 0.36\n",
      "[40/50] [400/466] mean_loss : 0.36\n",
      "[41/50] [000/466] mean_loss : 0.31\n",
      "[41/50] [200/466] mean_loss : 0.36\n",
      "[41/50] [400/466] mean_loss : 0.36\n",
      "[42/50] [000/466] mean_loss : 0.30\n",
      "[42/50] [200/466] mean_loss : 0.35\n",
      "[42/50] [400/466] mean_loss : 0.35\n",
      "[43/50] [000/466] mean_loss : 0.23\n",
      "[43/50] [200/466] mean_loss : 0.35\n",
      "[43/50] [400/466] mean_loss : 0.34\n",
      "[44/50] [000/466] mean_loss : 0.31\n",
      "[44/50] [200/466] mean_loss : 0.34\n",
      "[44/50] [400/466] mean_loss : 0.35\n",
      "[45/50] [000/466] mean_loss : 0.25\n",
      "[45/50] [200/466] mean_loss : 0.33\n",
      "[45/50] [400/466] mean_loss : 0.35\n",
      "[46/50] [000/466] mean_loss : 0.47\n",
      "[46/50] [200/466] mean_loss : 0.33\n",
      "[46/50] [400/466] mean_loss : 0.34\n",
      "[47/50] [000/466] mean_loss : 0.43\n",
      "[47/50] [200/466] mean_loss : 0.33\n",
      "[47/50] [400/466] mean_loss : 0.33\n",
      "[48/50] [000/466] mean_loss : 0.30\n",
      "[48/50] [200/466] mean_loss : 0.33\n",
      "[48/50] [400/466] mean_loss : 0.33\n",
      "[49/50] [000/466] mean_loss : 0.39\n",
      "[49/50] [200/466] mean_loss : 0.33\n",
      "[49/50] [400/466] mean_loss : 0.32\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(EPOCH):\n",
    "    losses=[]\n",
    "    for i, batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n",
    "        inputs, targets, input_lengths, target_lengths = pad_to_batch(batch, source2index, target2index)\n",
    "        \n",
    "        input_masks = torch.cat([Variable(ByteTensor(tuple(map(lambda s: s ==0, t.data)))) for t in inputs]).view(inputs.size(0), -1)\n",
    "        start_decode = Variable(LongTensor([[target2index['<s>']] * targets.size(0)])).transpose(0, 1)\n",
    "        encoder.zero_grad()\n",
    "        decoder.zero_grad()\n",
    "        output, hidden_c = encoder(inputs, input_lengths)\n",
    "        \n",
    "        preds = decoder(start_decode, hidden_c, targets.size(1), output, input_masks, True)\n",
    "                                \n",
    "        loss = loss_function(preds, targets.view(-1))\n",
    "        losses.append(loss.data.tolist()[0] )\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm(encoder.parameters(), 50.0) # gradient clipping\n",
    "        torch.nn.utils.clip_grad_norm(decoder.parameters(), 50.0) # gradient clipping\n",
    "        enc_optimizer.step()\n",
    "        dec_optimizer.step()\n",
    "\n",
    "        if i % 200==0:\n",
    "            print(\"[%02d/%d] [%03d/%d] mean_loss : %0.2f\" %(epoch, EPOCH, i, len(train_data)//BATCH_SIZE, np.mean(losses)))\n",
    "            losses=[]\n",
    "\n",
    "    # You can use http://pytorch.org/docs/master/optim.html#how-to-adjust-learning-rate\n",
    "    if RESCHEDULED == False and epoch  == EPOCH//2:\n",
    "        LR *= 0.01\n",
    "        enc_optimizer = optim.Adam(encoder.parameters(), lr=LR)\n",
    "        dec_optimizer = optim.Adam(decoder.parameters(), lr=LR * DECODER_LEARNING_RATIO)\n",
    "        RESCHEDULED = True"
   ]
  },
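  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The manual rescheduling above can also be written with `torch.optim.lr_scheduler`; a sketch of the equivalent setup (not used in this run):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# sketch: multiply the learning rate by 0.01 once, at the halfway epoch\n",
    "enc_scheduler = optim.lr_scheduler.StepLR(enc_optimizer, step_size=EPOCH // 2, gamma=0.01)\n",
    "dec_scheduler = optim.lr_scheduler.StepLR(dec_optimizer, step_size=EPOCH // 2, gamma=0.01)\n",
    "# then call enc_scheduler.step() and dec_scheduler.step() once per epoch"
   ]
  },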
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize Attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# borrowed code from https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation-batched.ipynb\n",
    "\n",
    "def show_attention(input_words, output_words, attentions):\n",
    "    # Set up figure with colorbar\n",
    "    fig = plt.figure()\n",
    "    ax = fig.add_subplot(111)\n",
    "    cax = ax.matshow(attentions.numpy(), cmap='bone')\n",
    "    fig.colorbar(cax)\n",
    "\n",
    "    # Set up axes\n",
    "    ax.set_xticklabels([''] + input_words, rotation=90)\n",
    "    ax.set_yticklabels([''] + output_words)\n",
    "\n",
    "    # Show label at every tick\n",
    "    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n",
    "    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n",
    "\n",
    "#     show_plot_visdom()\n",
    "    plt.show()\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Source :  he is telling a lie .\n",
      "Truth :  il dit un mensonge .\n",
      "Prediction :  il dit un mensonge .\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAEUCAYAAADOaUa5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGHtJREFUeJzt3X+4XVV95/H3J0FEIAJjoE4JGNQgzYMWnBic+pMWbKQK\nM4/WgqUFhaYdhQHRPmLrA31oqZVOnYrF1qiIMlpECzXWKBYFaalAEqBgMmAzsUqCBa6liFIC5Hzm\nj72vHm5z7znXfe7d+677eeXZzz17n3XX+d6T5HvXWWvttWSbiIgo04K2A4iIiJmTJB8RUbAk+YiI\ngiXJR0QULEk+IqJgSfIREQVLko+IKFiSfEREwZLkIyIKtlvbAUT5JJ2zi8sPARtt3z7b8UTMJ8qy\nBjHTJH0KWAF8vr70GuAOYCnwGdsXtRRaRPGS5GPGSboBOM72D+rzvYEvAKuoWvPL24wvomTpk4/Z\ncACwo+/8ceCnbP/7hOsRMWLpk4/Z8EngZkmfq89fC3xK0l7A5vbCiihfumtiVkh6EfBz9emNtje0\nGU/EfJEkH7NC0kLgp+j79Gj7O+1FFHOVpGcC9znJayjprokZJ+lM4HzgPmAnIMDAC9qMK+YeSfsB\nW4GTgM8NKB6kJR+zQNIW4Cjb32s7lpjbJJ0BHAsssP3atuOZCzK7JmbDPVQ3P0U09SbgDOAgSf+5\n7WDmgnTXxGzYClwv6Qv0TZm0/b72Qoq5RtIKYMz2PZI+AZwKvKfdqLovLfmYDd8B/hbYHVjUd0RM\nx2nAR+vHlwO/1mIsc0b65COi8yTtCWwCDrX9eH3tauD9tq9vM7auS5KPGSPpT22fLenzVLNpnsT2\n8S2EFXOQpKcA+9m+v+/a0wFsf7+1wOaA9MnHTLq8/vq/Wo0i5jzbj0v6oaQFtnuSDgUOA77Ydmxd\nl5Z8RMwJkjYCLwP2A24E1gOP2f7VVgPruLTkY8ZIupNddNOMs52boWI6ZPsRSacBH7R9kaTsRzBA\nknzMpNe0HUAURZL+K/CrVDNtABa2GM+ckCQfM8b2t9uOIYpyNvAu4GrbmyQ9G7iu5Zg6L33yI1AP\nAv051Rrph0t6AXC87T9oObRWSXqYH3fXqP7q+rFtP72VwDpM0rOAZbavlfQ0YDfbD7cdV5skvQv4\nku3b2o5lLsrNUKPxYaoWxuMAtu8ATmw1og6wvcj20+tjUd/5oiT4/0jSbwCfBT5UX1oC/HV7EXXG\nVuAsSbdJukzSr9QLlcUQ0l0zGnvavkVS/7Un2gqmiyS9lKqF+jFJi4FFtr/Vdlwd81ZgJXAzgO1/\nknRAuyG1z/angU8DSDqSatvIq+rlq6+lauXf0mKInZaW/GiMSXoOddeEpNcD3203pO6QdD7wTqpP\nO1Atb/B/2ouos3bYfmz8RNJuTDE7aT6yfZvt99g+mmpgfxNwesthdVqS/Gi8leoj9mGStlMNEP1W\nmwFJekm9vR6STpb0vrq/tw3/HTge+CGA7XvJ2jW78jVJvwM8TdKxwGeAz7ccUydI2lPSz064vC9w\nk+3VbcQ0VyTJj8Z24GPAhcAVVItxndJqRNVA8CP1f4y3A/8P+ERLsTxW7+Iz/klnr5bieBJJ+0la\nKenl40fLIZ0LPADcCfwmsA54d6sRdcfjVF00/f92PgJkueEB0ic/Gp8D/g24Fbi35VjGPWHbkk4A\n/sz2R+ubSNpwpaQPAfvWg4tvphqsbo2k04GzqAY3bwdeDHwd+Pm2YrLdo3pfWn1vuqhe1uBq4A3A\nxyQdDOyfvYIHS5IfjSW2V7UdxAQP11PPTgZeLmkB8JSWYtmfatbI94HnAecBx7QUy7izgBdRfdw/\nWtJhwB+2EYikK22/YbI7hHNn8I98BFhD9an51+uvMUCS/Gj8g6Tn276z7UD6/ArwRuA02/9St3z+\nuKVYjrX9TqpuLAAk/QnVYGxbHrX9qCQkPdX2XZKe11IsZ9Vfc4fwFOq/I9X3pZxItY5NDJCboRro\na3ntBiyjms+7gx/f7DOvW2CS/gfwFuDZVGMC4xYBN9o+uZXA+NFa5G+iGiT/eeBB4Cm2j2srprlA\n0jNt/0uLr38qVXffdtsntRXHXJIk38Cg2Spt3NYv6e9tv3TC3abQwl2mkvahWjHwPVSDiuMetv2v\nsxXHIJJeAexDNd/6sUHlZ+D1J/5d/egpOnZnsKQv2P6lFl9/T6rpya+zfW1bccwlSfIREQXLFMqI\niIIlyc8ASZ26OSPxTK1r8UD3Yko8s0PSpZLul/SNSZ6XpIslbZF0h6QXDqozSX5mdO0fYOKZWtfi\nge7FlHhmx2VUa/NM5tVUkzyWUb0Hfz6owiT5iIiOsH0DMNWkhBOAT7hyE9UNhlPe9Zt58rXFixd7\n6dKlI6nr4IMPZsWKFY1GtDdu3DiSWMZJ6tQIe+IZrGsxFRrPmO39m1SwatUqj42NDVV248aNm4BH\n+y6tsb1mGi93IHBP3/m2+tqkCyImydeWLl3Khg3duUN6wrLFETEzGk9zHhsbGzp3SHrU9oqmrzkd\nSfIREQ3N4lT07cBBfedL6muTSp98REQDBnb2ekMdI7AW+PV6ls2LgYdsT7l3RVryERGNGI9obxdJ\nfwm8ElgsaRtwPvXCgrb/gmr56eOALcAjVEtzTClJPiKiCUNvRL01g9bjqfdleOt06kySj4hoqMvL\nwyTJR0Q0YKCXJB8RUa605CMiCmV7VDNnZkSSfEREQ2nJR0QUbFRTKGdCknxERAPVwGvbUUwuST4i\noqF010RElKrjA69Fr10j6R/qr0sn22klIqIJU7XkhznaUHRL3vbPtR1DRJQvN0O1RNIPbO/ddhwR\nUbYu98kX3V0ziKTVkjZI2vDAAw+0HU5EzEke+k8b5nWSt73G9grbK/bfv9EOYBExT7lehXKYow1F\nd9dERMyGXodn1yTJR0Q0kFUoIyIK1+WB16KT/PjMGtv/DBzebjQRUSQ7LfmIiJKlJR8RUSgDO5Pk\nIyLKlZZ8RETBkuQjIgrlDLxGRJQtLfmIiIIlyUdEFKqaXZNlDSIiipU9XiMiStXirk/DSJKPiGhg\nfPu/rkqSj4hoKFMo54CNGzciqe0wOquLLZX8fUVXdPH/x7gk+YiIBmyzM5uGRESUq639W4eRJB8R\n0VCXp1DO6428IyKaGp9dM8wxiKRVku6WtEXSubt4/mBJ10m6TdIdko4bVGeSfEREQ6NI8pIWApcA\nrwaWAydJWj6h2LuBK20fCZwIfHBQbOmuiYhoYnQDryuBLba3Aki6AjgB2Nz/asDT68f7APcOqjRJ\nPiKigRHeDHUgcE/f+TbgqAllfg/4sqQzgb2AYwZVmu6aiIiGevWa8oMOYLGkDX3H6mm+1EnAZbaX\nAMcBl0uaMo+nJR8R0dA0plCO2V4xyXPbgYP6zpfU1/qdBqwCsP11SXsAi4H7J3vBtOQjIhqyhzsG\nWA8sk3SIpN2pBlbXTijzHeAXA
CT9DLAH8MBUlaYlHxHRgBnN2jW2n5B0BnANsBC41PYmSRcAG2yv\nBd4OfFjS2+qXPtUDBgSS5CMimhjhsga21wHrJlw7r+/xZuAl06kzST4iooEsNRwRUbgk+Vki6feA\nH1DdLHCD7WslnQ2ssf1Iq8FFRLG6vJ58kbNrbJ9n+9r69GxgzzbjiYiSeeg/bZjzLXlJvwucQjVP\n9B5go6TLgL8Bfro+rpM0Zvvo1gKNiCINOT2yNXM6yUv6L1RzSY+g+lluBTaOP2/7YknnAEfbHmsn\nyogoXTYNmTkvA64e72+XNPHGgSnVtxRP97biiIgfGdU8+Zky15N8I7bXAGsAJHX3bykiOq3Ls2vm\n+sDrDcB/k/Q0SYuA1+6izMPAotkNKyLmjSHXkm/rF8GcbsnbvlXSp4F/pBp4Xb+LYmuAL0m6NwOv\nETEjOtySn9NJHsD2hcCFUzz/AeADsxdRRMw3vZ1J8hERRaqmUCbJR0QUK0k+IqJY7Q2qDiNJPiKi\nIfeS5CMiipQ++YiIwjnLGkRElKvDDfkk+YiIRuz0yUdElCx98hERhcoerxERhUuSj4golY13ZnZN\nRESx0pKPOU9S2yH8B137j9XF9yhmR8f+KT5JknxERAMZeI2IKFmWNYiIKJnpZeA1IqJcaclHRBQq\nq1BGRJQuST4iolzubpd8knxERFPpromIKJVNL5uGRESUqes3Qy1oO4CIiDnN1UbewxyDSFol6W5J\nWySdO0mZN0jaLGmTpE8NqjMt+YiIpkbQkpe0ELgEOBbYBqyXtNb25r4yy4B3AS+x/aCkAwbVm5Z8\nREQjxh7uGGAlsMX2VtuPAVcAJ0wo8xvAJbYfBLB9/6BKk+QjIhrq9TzUASyWtKHvWN1XzYHAPX3n\n2+pr/Q4FDpV0o6SbJK0aFFu6ayIiGnDdJz+kMdsrGrzcbsAy4JXAEuAGSc+3/W9TfcOcJmkp8De2\nD6/P3wHsTfUm3AwcDewLnGb779qJMiJKNqLZNduBg/rOl9TX+m0Dbrb9OPAtSd+kSvrrJ6u09O6a\n3WyvBM4Gzm87mIgo04j65NcDyyQdIml34ERg7YQyf03VgEXSYqrum61TVTrnW/IDXFV/3Qgsnfhk\n3R+2euL1iIjhDZXAB9diPyHpDOAaYCFwqe1Nki4ANtheWz/3KkmbgZ3Ab9v+3lT1lpDkn+DJn0j2\n6Hu8o/66k138rLbXAGsAJHX3boaI6K4RrkJpex2wbsK18/oeGzinPoZSQnfNfcABkp4h6anAa9oO\nKCLmDwPe6aGONsz5lrztx+uPM7dQDVLc1XJIETHPdHlZgzmf5AFsXwxcPMXzY+yiTz4iorHhBlVb\nU0SSj4ho0zTmyc+6JPmIiIbSko+IKFTXlxpOko+IaMLG2TQkIqJc2eM1IqJg6a6JiCjVCO94nQlJ\n8hERDWTgNSKiaKa3s7ud8knyERFNpLsmIqJwSfIREeXqcI5Pko+IaCIDrxEzRFLbITxJ1/6jd+39\nKdb0NvKedUnyERGNmF6WNYiIKFfXPsX1S5KPiGgqST4iokxOn3xERNk63JBPko+IaCZ7vEZElMtk\ndk1ERKlM+uQjIoqW7pqIiGK50yOvSfIREU1kqeGIiLL1dibJR0QUKatQRkSULN01EREl6/bNUAsG\nFZC0VNJdki6T9E1Jn5R0jKQbJf2TpJWS9pJ0qaRbJN0m6YT6e0+VdJWkL9VlL6qvL6zr+4akOyW9\nrb5+hKSbJN0h6WpJ+9XXr5f03rr+b0p6WX19T0lXStpcl79Z0or6uVdJ+rqkWyV9RtLeM/UmRsT8\nZnuoow3DtuSfC/wy8GZgPfBG4KXA8cDvAJuBr9p+s6R9gVskXVt/7xHAkcAO4G5JHwAOAA60fThA\n/T0AnwDOtP01SRcA5wNnj8dqe6Wk4+rrxwBvAR60vVzS4cDtdX2LgXcDx9j+oaR3AucAF0zz/YmI\nGKjLN0MNbMnXvmX7Tts9YBPwFVe/lu4ElgKvAs6VdDtwPbAHcHD9vV+x/ZDtR6l+GTwL2Ao8W9IH\nJK0Cvi9pH2Bf21+rv+/jwMv7Yriq/rqxfk2oftFcAWD7G8Ad9fUXA8uBG+uYTqlf90kkrZa0QdKG\nId+HiIgnGV+FcphjEEmrJN0taYukc6co9zpJHu+5mMqwLfkdfY97fee9uo6dwOts3z0hkKMmfO9O\nqhb5g5J+FvhF4LeANwBvGzKGnUPELeBvbZ80VSHba4A1dazd/VUcEZ02iq4YSQuBS4BjgW3Aeklr\nbW+eUG4RcBZw8zD1DtuSH+Qa4EzVm0pKOnKqwnV3ygLbf0XVrfJC2w8BD473twO/BnxtsjpqN1L9\ngkDScuD59fWbgJdIem793F6SDp3+jxURMchw/fFD/CJYCWyxvdX2Y1S9FCfsotzvA+8FHh0mulHN\nrvl94E+BOyQtAL4FvGaK8gcCH6vLAryr/noK8BeS9qTq0nnTgNf9IPBxSZuBu6i6kh6y/YCkU4G/\nlPTUuuy7gW9O78eKiBhgdJuGHAjc03e+DTiqv4CkFwIH2f6CpN8eptKBSd72PwOH952fOslzv7mL\n770MuKzvvD/xv3AX5W+n6k+feP2VfY/H+HGf/KPAybYflfQc4Frg23W5rwIvmuJHi4gYiWl01yye\nMAa4pu42HqhuFL8POHU6sc31efJ7AtdJegpVP/xb6o85ERGzYpp3vI7ZnmywdDtwUN/5kvrauEVU\njerr657xZwJrJR1ve9LJI3M6ydt+GBg4uhwRMXOMR7NpyHpgmaRDqJL7iVTT1atXqcYtF4+fS7oe\neMdUCR5GN/AaETE/Gdwb7piyGvsJ4AyqiSz/F7jS9iZJF0g6/icNb0635CMiumBUd7PaXgesm3Dt\nvEnKvnKYOpPkIyIa6vLaNUnyERENZKnhiIiS2fR2jmTgdUYkyUdENJWWfEREuUySfEREkZydoSIi\nSmY8aBJ8i5LkIyIaSks+IqJgvdEsazAjkuQjIhqo1opPko+IKFe6ayIiypUplBERBcvAa0REsUyv\nt7PtICaVJB8R0UBuhoqIKFySfEREwZLkIyKK5UyhjIgomcnNUBERRbKzrEFERMGcPvmIiJJl7ZqI\niIKlJR8RUbAk+YiIUjlTKCMiimWg56xdExFRqMyu6SxJq4HVbccREXNbknxH2V4DrAGQ1N2/pYjo\ntCT5iIhCVeOumScfEVEo4w4va7Cg7QBmg6R1kn667Tgiokwe8k8b5kVL3vZxbccQEeVKn3xERLGc\nPvmIiFJ1fY/XedEnHxExk2wPdQwiaZWkuyVtkXTuLp4/R9JmSXdI+oqkZw2qM0k+IqKhXq831DEV\nSQuBS4BXA8uBkyQtn1DsNmCF7RcAnwUuGhRbknxERCMG94Y7prYS2GJ7q+3HgCuAE570SvZ1th+p\nT28ClgyqNEk+IqKhaUyhXCxpQ9/Rv6zKgcA9fefb6muTOQ344qDYMvAaEdHANAdex2yvaP
qakk4G\nVgCvGFQ2ST4ioqERza7ZDhzUd76kvvYkko4Bfhd4he0dgypNko+IaGRk8+TXA8skHUKV3E8E3thf\nQNKRwIeAVbbvH6bSJPmIiIYGzZwZhu0nJJ0BXAMsBC61vUnSBcAG22uBPwb2Bj4jCeA7to+fqt4k\n+YiIBkZ5M5TtdcC6CdfO63t8zHTrTJKPiGgke7xGRBTNZO2aiIhidXntmiT5iIhGPJKB15mSJB8R\n0UC2/4uIKFy6ayIiCpYkHxFRrEyhjIgoWlubdA8jST4iogEber2dbYcxqST5iIhGhtvary1J8hER\nDSXJR0QULEk+IqJguRkqIqJUzhTKiIhiGeilJT86kk4EnmP7wrZjiYiAbnfXLGg7gEEk7S5pr75L\nrwa+NGTZiIgZVk2hHOZoQ2eTvKSfkfQnwN3AofU1AUcAt0p6haTb6+M2SYuA/YBNkj4k6UXtRR8R\n80mS/JAk7SXpTZL+HvgwsBl4ge3b6iJHAv/o6t16B/BW20cALwP+3fZ9wPOA64AL6+T/PyX9p9n/\naSJiPhjf47WrSb5rffLfBe4ATrd91y6eXwV8sX58I/A+SZ8ErrK9DcD2DuAK4ApJBwN/Blwk6dm2\n7+2vTNJqYPXM/CgRMT8Yd3hZg0615IHXA9uBqySdJ+lZE55/FfBlANt/BJwOPA24UdJh44UkHSDp\n7cDngYXAG4H7Jr6Y7TW2V9heMSM/TUTMCx7yTxs61ZK3/WXgy5KeAZwMfE7SGFUyfxDYzfb3ACQ9\nx/adwJ11//thkr4LfBw4DLgcOM729jZ+loiYP3LH6zTVifz9wPslrQR2AscC1/YVO1vS0UAP2ETV\njbMHcDFwnbv8rkdEUbqcbjqZ5PvZvgVA0vnAR/qun7mL4juAr85SaBER9aBqd+fJdz7Jj7N9etsx\nRETsSlryEREF6/XSko+IKFda8hERpTImLfmIiCKN3/HaVUnyERENJclHRBQsST4iolim1+G1a5Lk\nIyIa6HqffNcWKIuImHvG93kddAwgaZWkuyVtkXTuLp5/qqRP18/fLGnpoDqT5CMiGhl2Dcqpk7yk\nhcAlVLvfLQdOkrR8QrHTgAdtPxf438B7B0WXJB8R0ZDdG+oYYCWwxfZW249R7YtxwoQyJ1CttAvw\nWeAX6h3zJpU++YiIhka0rMGBwD1959uAoyYrY/sJSQ8BzwDGJqs0Sf7HxoBvj6iuxUzxprcg8Uxt\nJPEMaFBNV5Hv0QiNKp6JGxP9JK6himcYe0ja0He+xvaaEcQwqST5mu39R1WXpA1d2m0q8Uyta/FA\n92JKPJOzvWpEVW0HDuo7X1Jf21WZbZJ2A/YBvjdVpemTj4johvXAMkmHSNodOBFYO6HMWuCU+vHr\nga8O2iApLfmIiA6o+9jPoOr+WQhcanuTpAuADbbXAh8FLpe0BfhXql8EU0qSnxkz2sf2E0g8U+ta\nPNC9mBLPLLC9Dlg34dp5fY8fBX55OnWqy3dqRUREM+mTj4goWJJ8RETBkuQjIgqWJB8RUbAk+YiI\ngiXJR0QULEk+IqJg/x/AvR9vGPeJlAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f215c14de48>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "test = random.choice(train_data)\n",
    "input_ = test[0]\n",
    "truth = test[1]\n",
    "\n",
    "output, hidden = encoder(input_, [input_.size(1)])\n",
    "pred, attn = decoder.decode(hidden, output)\n",
    "\n",
    "input_ = [index2source[i] for i in input_.data.tolist()[0]]\n",
    "pred = [index2target[i] for i in pred.data.tolist()]\n",
    "\n",
    "\n",
    "print('Source : ',' '.join([i for i in input_ if i not in ['</s>']]))\n",
    "print('Truth : ',' '.join([index2target[i] for i in truth.data.tolist()[0] if i not in [2, 3]]))\n",
    "print('Prediction : ',' '.join([i for i in pred if i not in ['</s>']]))\n",
    "\n",
    "if USE_CUDA:\n",
    "    attn = attn.cpu()\n",
    "\n",
    "show_attention(input_, pred, attn.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# TODO "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* BLEU\n",
    "* Beam Search\n",
    "* <a href=\"http://www.aclweb.org/anthology/P15-1001\">Sampled Softmax</a>"
   ]
  },
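  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For BLEU, the already-imported `nltk` ships a scorer; a minimal sketch with hypothetical sentences (a real evaluation would collect hypotheses from `decoder.decode` over a held-out set):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction\n",
    "\n",
    "# hypothetical references/hypotheses, tokenized like the training data\n",
    "references = [[['je', 'comprends', '.']]] # one list of references per sentence\n",
    "hypotheses = [['je', 'comprends', '.']]\n",
    "smooth = SmoothingFunction().method1      # avoids zero scores on short sentences\n",
    "print(corpus_bleu(references, hypotheses, smoothing_function=smooth))"
   ]
  },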
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Further topics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* <a href=\"https://s3.amazonaws.com/fairseq/papers/convolutional-sequence-to-sequence-learning.pdf\">Convolutional Sequence to Sequence learning</a>\n",
    "* <a href=\"https://arxiv.org/abs/1706.03762\">Attention is all you need</a>\n",
    "* <a href=\"https://arxiv.org/abs/1711.00043\">Unsupervised Machine Translation Using Monolingual Corpora Only</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Suggested Reading "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* <a href=\"https://arxiv.org/pdf/1709.07809.pdf\">SMT chapter13. Neural Machine Translation</a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
